from Numeric import * 
from Tkinter import *
from visual import *
from random import random
import time,sys,thread


##  Simulation class definition
class Simulation:
    def __init__(self,parent):
        self.parent = parent

        self.unitCellsX = 4
        self.unitCellsY = 8

        self.unitCellSizeX = 4
        self.unitCellSizeY = 2
        
        self.unitCellSpinsX = 3
        self.unitCellSpinsY = 2

        self.domainWallTypeX = 1        # none=0, bond=1, site=2
        self.domainWallTypeY = 0
        self.domainWallExchX = -1        # antiferro=-1, ferro=1
        
        self.nUnitCells = self.unitCellsX*self.unitCellsY
        self.nUnitCellSpins = self.unitCellSpinsX*self.unitCellSpinsY
        self.nTotSpins = self.unitCellSpinsX*self.unitCellSpinsY*self.unitCellsX*self.unitCellsY

        self.dx = 1.0
        self.dy = 1.0
        
        self.dt = 0.005
        self.paused = 1
        self.pause_it = 0

        self.Ja = -1
        self.Jb = -.4
        
        self.k = 2*pi*vector((self.unitCellSpinsX/self.unitCellSizeX)/(self.unitCellsX*self.unitCellSizeX*self.dx),(self.unitCellSpinsY/self.unitCellSizeY)/(self.unitCellsY*self.unitCellSizeY*self.dy),0)                             # 2pi/L * (kx integer,ky integer)
        # this has been modified so that using cos(k.s.pos) produces waves commensurate with the magnetic unit cell

        self.spin_length = 1.0
        self.baseSpinSigma = .001
        
        self.spinvec_length = hypot(self.dx,self.dy)*.35

        self.toggle_sublattices = 0
        self.toggle_torques = BooleanVar()
        self.toggle_circles = BooleanVar()
        self.toggle_colors = BooleanVar()
        self.toggle_polarization = BooleanVar()
        self.toggle_norm_S = BooleanVar()
        self.toggle_fix_sZ = BooleanVar()

        self.setup_scene()
        self.setSpinMagRatios(self.k)
        
        self.lattice = zeros((self.unitCellsX,self.unitCellsY,self.unitCellSizeX,self.unitCellSizeY))
        
##        self.spins = self.create_lattice(self.spins, self.Ntotx, self.Ntoty, self.dx, self.dy)
##        self.spins = self.initialize_spins(self.spins, self.k)
##        self.pause_it = 0
##        self.spins = self.initialize_circles(self.spins)
##        thread.start_new_thread(self.animate_spins,(self.spins,))

#        print a[][]




        print z



##                        
##        for m in range(self.unitCellsX):        # m starts at '0'
##            a[m] =
##
##            
##            for n in range(self.unitCellsY):        # n starts at '0'
##            a[m][n] =
##
##            
##            for i in range(self.unitCellsX):        # i starts at '0'
##            a[m][n][i] =
##
##            
##            for j in range(self.unitCellsX):        # j starts at '0'
##            a[m][n][i][j] = m+n+i+j
##            
##            ))))

        
    def setup_scene(self):
        scene.title = "Simulation"
        scene.width = 580
        scene.height = 500
        scene.x = 20
        scene.y = 20
        scene.autoscale = 0
        scene.up = (0,0,1)
        scene.forward = (0,.1,-5)
        scene.lights = [vector(-.5,.5,.5), vector(.5,-.5,.5)]
        scene.background = (0,0,0)
        scene.uniform = 1
        scene.range = 2*(self.unitCellsX*self.unitCellSizeX)**(.75)

        scene.select()


    def setSpinMagRatios(self, _k=(1/8.,1/8.,1/8.)):
#        self.spin_mag_ratio = (4-(16-4*(cos(_k.x*self.dx)+cos(_k.y*self.dy))**2)**(.5))/(2*(cos(_k.x*self.dx)+cos(_k.y*self.dy)))


        self.spin_mag_ratio = -(-abs(self.Ja+self.Jb)+2.0*(self.Ja*self.Jb*(sin(1.5*_k.x*self.dx))**2)**(.5))/((self.Ja**2+self.Jb**2+2.0*self.Ja*self.Jb*cos(3.0*_k.x*self.dx))**(.5))

        print 'spin up to down magnitude ratio', self.spin_mag_ratio


                

#       This part needs to instead include a way of setting up the grouping of objects.  I suggest that you start with the unit cell,
#       and then make a list of references to duplicate copies of the unit cell, tiled in the plane, with boundarys that wrap around.
#       This shouldnt be too difficult.  Remember, there are three cases.  Spins on the boundary of a unit cell, which must in some way
#       reference the next spin over, in the next unit cell.  This requires referencesing it somehow.  I suggest that we reference
#       them by a elements of a multidimensional list.  If the list on one tier has a certain attribute: namely, if its index number is
#       equal to a number that is the min or max of the size of the unit cell, or the min or max of the lattice, gets a wrap around in
#       that reference frame.  Both a wrap around from the unit cell and the lattice occur at the edge of the lattice.  Only the wrap
#       around from the unit cell occurs for boundarys interior to the boundary of the lattice.  This is like a fractal, or just two
#       layers of it.  We need to reference the layers in the same way but on different scales.  Once this case by case indexing of
#       neighbors occurs we should be complete, because we will have generated two lists, one with references to its neighbors,
#       and one with the additional coupling information.  The two list will correspond.  Torques will now be calculated by going unit
#       cell to unit cell, spin to spin, adding up their torques and replacing a temp variable for torque (reset every frame)
#       for the next update, where we merely use this new torque information like before.  Setting up the lattice will also require
#       this type of mimicry.  You will have to add an additional amount to the position, generated by the the spin number mod the unit
#       cell length in that direction.  However, we will also have to rearrage the way that site and bond centered graphics appear.
#       This is just filling in the holes of the unit cell that dont have spins.  I will for now only do vertical stripes.  Therefore,
#       we will need to add a circle or line every certain number of spacings.  This corresponds to a position located at the 'end' of
#       the unit cell.  if it is 4 long, its in position 4, whatever that index is.  Jb and Ja automatically generated as from mathching
#       indicies to the rule spoken about before.  Generation of initial condition: USE phase angle given by mathematica, spin by spin
#       (keep adding up rotations on cos and sin as per the move from spin to spin and cell to cell, s1 = 0 deg.).  This information
#       should be stored in an array for the unit cell, the two pieces of information being the spin magnitude and its direction.  This
#       information is then initialized in the spins by the cos and sin function process mentioned before.
#       The method will commonly reduce to a list of lists, which in turn have elements which are lists of lists.  Matricies within matricies.

#               Like looking through a very special kind of crystal.

        ## Phase shift when linking between spins on lattice boundary, all spin positions,  This will give a Wizard of Oz situation, where one unit cell does all the work of the lattice.

    def create_lattice(self,_lattice=[[[[]]]]):

        for m in range(self.unitCellsX):        # m+1 starts at '1'

            for n in range(self.unitCellsY):        # n+1 starts at '1'

                for i in range(self.unitCellsX):        # i+1 starts at '1'

                    for j in range(self.unitCellsY):        # j+1 starts at '1'

                        _lattice[[[[m,n,i,j]]]] = (m+1)+(n+1)+(i+1)+(j+1)

                        print blah

        return _lattice
                        


##    def create_lattice(self, _lattice=[], _Ntotx=8, _Ntoty=8, _deltax=1.0, _deltay=1.0):
##        self.xmin = -_Ntotx*_deltax/2.
##        self.ymin = -_Ntoty*_deltay/2.
##        self.nz = 0
##
##        for ny in range(_Ntoty):
##            y = self.ymin + ny*_deltay
##            for nx in range(_Ntotx):
##                x = self.xmin + nx*_deltax + (nx/2)*_deltax
##
##                _lattice.append(frame())
##                _lattice[-1].pos = vector(x,y,0)
##                _lattice[-1].spin = vector()
##                _lattice[-1].spinvec = arrow(pos=(x,y,0), axis=(0,0,1), color=(.9,.9,.9), shaftwidth=0.4)
##              
##                _lattice[-1].torque = vector()
##                _lattice[-1].torquevec = arrow(pos=(x,y,0), axis=(0,0,1), color=(.5,.5,1), shaftwidth=0.05)
##
##                _lattice[-1].circle = curve()
##
##                _lattice[-1].nearx = range(2)
##                _lattice[-1].neary = range(2)
##                _lattice[-1].indices = (nx,ny,1-(nx+ny)%2)
##              
##        for s in _lattice:
##              nx, ny, nz = s.indices
##              if nx == 0:                           # leftmost spin in a row
##                  nspinl = _Ntotx*ny + _Ntotx-1     # wrap around to spin on right side
##                  s.nearx[0] = nspinl               # reference this by its list element number, given by the order in which the spins are added to the list...
##              else:
##                  nspinl = _Ntotx*ny + nx-1
##                  s.nearx[0] = nspinl
##              if nx == _Ntotx-1:                    # rightmost spin in a row
##                  nspinr = _Ntotx*ny
##                  s.nearx[1] = nspinr
##              else:
##                  nspinr = _Ntotx*ny + nx+1 
##                  s.nearx[1] = nspinr
##
##              if ny == 0:                           # bottom spin in a column
##                  nspind = _Ntotx*(_Ntoty-1) + nx 
##                  s.neary[0] = nspind
##              else:
##                  nspind = _Ntotx*(ny-1) + nx
##                  s.neary[0] = nspind
##              if ny == _Ntoty-1:                    # top spin in a column
##                  nspinu = nx
##                  s.neary[1] = nspinu
##              else:
##                  nspinu = _Ntotx*(ny+1) + nx
##                  s.neary[1] = nspinu
##
##        for lnx in range(_Ntotx/2+1):
##            startx = self.xmin + lnx*_deltax*3. - _deltax
##            circle = curve(x=startx + self.spinvec_length*.3*cos(arange(0,2*pi,.31)), y=-.5 + self.spinvec_length*.3*sin(arange(0,2*pi,.31)), color=(.5,.5,.5))
##
##        return _lattice



    def initialize_spins(self, _lattice=[[[[]]]], _k=(1/8.,1/8.,1/8,)):
        self.pause_it = 1
        while self.paused == 0:
            pass
       
        was_split = 0
        if self.toggle_sublattices == 1:
            was_split = 1
            self.t_sublattices()
      
        # if any spin is over magnitude of 1: renormalize all spins to that spin, including itsself.

        # spin 

        
        spinA_zmag = (self.spin_length**2-self.spinA_sigma**2)**(.5)
        spinB_zmag = (self.spin_length**2-self.spinB_sigma**2)**(.5)

        if _k.x != 0:
            k_angle = arctan(_k.y/_k.x)
        else:
            k_angle = pi/2.

        alpha = arccos((((self.Ja+self.Jb*cos(3.0*_k.x*1.0))**2)/(self.Ja**2+self.Jb**2+2*self.Ja*self.Jb*cos(3.0*_k.x*1.0)))**(.5))
        print 'alpha is', alpha

        for s in _lattice:
            if s.indices[2] == 1:
                spinx = self.spinA_sigma*cos(dot(_k, s.pos) + k_angle)
                spiny = self.spinA_sigma*sin(dot(_k, s.pos )+ k_angle)
                spinz = spinA_zmag
            else:
                spinx = -self.spinB_sigma*cos(dot(_k, s.pos) + k_angle - _k.x*1.0 + alpha)
                spiny = -self.spinB_sigma*sin(dot(_k, s.pos) + k_angle - _k.x*1.0 + alpha)
                spinz = -spinB_zmag
                
            s.spin = vector(spinx,spiny,spinz)
            s.spin = norm(s.spin)

        if was_split == 1:
            was_split = 0
            self.t_sublattices()
        return _lattice


    def animate_spins(self, _lattice=[]):
        while 1:
            if self.pause_it == 1:
                self.paused = 1
            else:
                self.paused = 0
                
            if self.pause_it != 1 and self.paused != 1:        
                for s in _lattice:
                    s.torque = vector(0,0,0)

#                    print 'spin', s.indices[0]
                        
                    for nnx in range(2):
                        nspinx = s.nearx[nnx]

                        if (s.indices[0]+1)%2 == 1 and nnx == 0:
#                            print 'left Jb, spin', nspinx
                            s.torque = s.torque + self.Jb*cross(s.spin,_lattice[nspinx].spin)

                        if (s.indices[0]+1)%2 == 0 and nnx == 1:
#                            print 'right Jb, spin', nspinx
                            s.torque = s.torque + self.Jb*cross(s.spin,_lattice[nspinx].spin)

                            
                        if (s.indices[0]+1)%2 == 0 and nnx == 0:
#                            print 'left Ja, spin', nspinx
                            s.torque = s.torque + self.Ja*cross(s.spin,_lattice[nspinx].spin)
                            
                        if (s.indices[0]+1)%2 == 1 and nnx == 1:
#                            print 'right Ja, spin', nspinx
                            s.torque = s.torque + self.Ja*cross(s.spin,_lattice[nspinx].spin)

                    for nny in range(2):
                        nspiny = s.neary[nny]
                        s.torque = s.torque + self.Ja*cross(s.spin,_lattice[nspiny].spin)

#                    print 'end spin\n'


                if self.toggle_fix_sZ.get() == 1:
                    for s in _lattice:
                        s.torque.z = 0

                for s in _lattice:
                    s.spin = s.spin + s.torque*self.dt                              # add torque
                    if s.indices[2] == 1:
                        s.spinvec.axis = 2*self.spinvec_length*norm(vector(s.spin.x,s.spin.y,0))
                    else:
                        if abs(self.spin_mag_ratio) < .5:
                            s.spinvec.axis = 2*self.spinvec_length*.5*norm(vector(s.spin.x,s.spin.y,0))
                        else:
                            s.spinvec.axis = 2*self.spinvec_length*abs(self.spin_mag_ratio)*norm(vector(s.spin.x,s.spin.y,0))
                        
                    s.spinvec.pos = s.pos - s.spinvec.axis/2

                if self.toggle_norm_S.get() == 1:
                    for s in _lattice:
                        s.spin = norm(s.spin)                                      # ensure spin remains magnitude 1

                if self.toggle_torques.get() == 1:
                    for s in _lattice:
                        s.torquevec.axis = 40.0*s.torque
                        s.torquevec.pos = s.pos + s.spinvec.axis/2

                if self.toggle_colors.get() == 1:
                    for s in _lattice:
                        if mag(self.k) != 0:
                            if s.spin.z >= 0:
                                up = 1
                            else:
                                up = -1
                            if self.toggle_polarization.get() == 1:
                                s.spinvec.color = (.5+.5*dot(norm(vector(s.spin.x,s.spin.y,0)),norm(-self.k)) , 0 , .5+.5*dot(norm(vector(s.spin.x,s.spin.y,0)),norm(self.k)))
                            else:
                                s.spinvec.color = (.5+.5*up*dot(norm(vector(s.spin.x,s.spin.y,0)),norm(-self.k)) , 0 , .5+.5*up*dot(norm(vector(s.spin.x,s.spin.y,0)),norm(self.k)))
        rate(100)

    def initialize_circles(self, _lattice=[]):
        self.pause_it = 1
        while self.paused == 0:
            pass

        for s in _lattice:
            if s.indices[2] == 1:
                s.circle = curve(x=s.pos.x+(self.spinvec_length)*cos(arange(0,2*pi,.31)), y=s.pos.y+(self.spinvec_length)*sin(arange(0,2*pi,.31)), color=(.5,.5,.5))
            else:
                s.circle = curve(x=s.pos.x+(self.spinvec_length)*self.spin_mag_ratio*cos(arange(0,2*pi,.31)), y=s.pos.y+(self.spinvec_length)*self.spin_mag_ratio*sin(arange(0,2*pi,.31)), color=(.5,.5,.5))
        self.pause_it = 0
        return _lattice

    
    def set_gammaratio(self, gammaratio_string):

        # if one value of gamma
        # otherwise we will need a matrix of sigmas to start with.
        
        if string.atoi(gammaratio_string) != 0:
            self.spin_mag_ratio = string.atoi(gammaratio_string)/1000.0
            self.spins = self.initialize_spins(self.spins, self.k)
            self.pause_it = 0


    def set_kx(self, kx_string):
        if (self.k.y == 0.0 and string.atoi(kx_string) == 0.0):
            pass
        else:
            self.k.x = (self.unitCellSpinsx/self.unitCellSizex)*string.atoi(kx_string)*(2*pi)/(self.Ntotx*self.dx)
            
        self.setSpinMagRatioss(self.k)
        gammaratio_widget.set(self.spin_mag_ratio*1000.0)
        print 'wavenumbers in natural units', self.k.x, self.k.y
        self.spins = self.initialize_spins(self.spins, self.k)
        self.pause_it = 0

    def set_ky(self, ky_string):
        if (self.k.x == 0.0 and string.atoi(ky_string) == 0.0):
            pass
        else:
            self.k.y = (self.unitCellSpinsy/self.unitCellSizey)*string.atoi(ky_string)*(2*pi)/(self.Ntoty*self.dy)

        self.setSpinMagRatioss(self.k)
        gammaratio_widget.set(self.spin_mag_ratio*1000.0)
        print 'wavenumbers in natural units', self.k.x, self.k.y
        self.spins = self.initialize_spins(self.spins, self.k)
        self.pause_it = 0


    def t_sublattices(self):
        self.pause_it = 1
        while self.paused == 0:
            pass
        
        if self.toggle_sublattices == 0:
            self.toggle_sublattices = 1
            for s in self.spins:
                if s.indices[2] == 1:
                    s.pos = s.pos + vector((self.Ntotx*self.dx)+self.dx*2.,0,0)/2.
                else:
                    s.pos = s.pos - vector((self.Ntotx*self.dx)+self.dx*2.,0,0)/2.
        else:
            self.toggle_sublattices = 0
            for s in self.spins:
                if s.indices[2] == 1:
                    s.pos = s.pos - vector((self.Ntotx*self.dx)+self.dx*2.,0,0)/2.
                else:
                    s.pos = s.pos + vector((self.Ntotx*self.dx)+self.dx*2.,0,0)/2.
        self.pause_it = 0
        pass
        
                
    def t_torques(self):
        self.pause_it = 1
        while self.paused == 0:
            pass
        
        for s in self.spins:
            s.torquevec.visible = self.toggle_torques.get()
        self.pause_it = 0


    def t_circles(self):
        self.pause_it = 1
        while self.paused == 0:
            pass
        
        for s in self.spins:
            s.circle.visible = self.toggle_circles.get()
        self.pause_it = 0

        
    def t_colors(self):
        self.pause_it = 1
        while self.paused == 0:
            pass
        
        for s in self.spins:
            s.spinvec.color = (1,1,1)
        self.pause_it = 0


    def t_polarization(self):
        pass
        
    def t_norm_S(self):
        pass

    def t_fix_sZ(self):
        pass

    def set_dt(self, dt_string):
        self.dt = string.atoi(dt_string)/1000.0

    def reset(self):
        self.pause_it = 1
        while self.paused == 0:
            pass
        
        self.spins = self.initialize_spins(self.spins, self.k)
        self.pause_it = 0



##  Instance creation
        
if __name__ == "__main__":
    # Kept at module scope (not inside a function) because Simulation.set_kx /
    # set_ky read the global gammaratio_widget created below.
    tkr = Tk()
    simu = Simulation(tkr)


    ##  TKR appearance and widget creation

    tkr.wm_geometry(newGeometry="400x320+600+20")
    tkr.wm_title("Controls")


    t_frame = Frame(tkr, relief=SUNKEN, borderwidth=0)
    t_frame.pack(side=TOP, padx=0, pady=0, expand=1)

    # Left column: display and dynamics checkboxes.
    toggles = Frame(t_frame, relief=SUNKEN, borderwidth=1)
    toggles.pack(side=LEFT, padx=10, pady=10, expand=1)

    circles_widget = Checkbutton(toggles, text="Circles", variable=simu.toggle_circles, command=simu.t_circles)
    simu.t_circles()            # sync circle visibility with the unchecked start state
    circles_widget.pack(fill=X, expand=1)

    torque_widget = Checkbutton(toggles, text="Torques", variable=simu.toggle_torques, command=simu.t_torques)
    simu.t_torques()            # sync torque-arrow visibility with the unchecked start state
    torque_widget.pack(fill=X, expand=1)

    color_widget = Checkbutton(toggles, text="Colors", variable=simu.toggle_colors, command=simu.t_colors)
    color_widget.toggle()       # start checked
    color_widget.pack(fill=X, expand=1)

    polarization_widget = Checkbutton(toggles, text="ALT Color Polarization", variable=simu.toggle_polarization, command=simu.t_polarization)
    polarization_widget.pack(fill=X, expand=1)

    norm_S_widget = Checkbutton(toggles, text="Keep Spins Normalized", variable=simu.toggle_norm_S, command=simu.t_norm_S)
    norm_S_widget.toggle()      # start checked
    norm_S_widget.pack(fill=X, expand=1)

    fix_sZ_widget = Checkbutton(toggles, text="Fixed Spin Z", variable=simu.toggle_fix_sZ, command=simu.t_fix_sZ)
    fix_sZ_widget.toggle()      # start checked
    fix_sZ_widget.pack(fill=X, expand=1)


    # Right column: action buttons.
    r_frame = Frame(t_frame, relief=SUNKEN, borderwidth=1)
    r_frame.pack(side=RIGHT, padx=10, pady=10, expand=1)

    sublattices_widget = Button(r_frame, text="Sublattices", command=simu.t_sublattices)
    sublattices_widget.pack(fill=X)

    reset_widget = Button(r_frame, text="Reset", command=simu.reset)
    reset_widget.pack(fill=X)


    scales = Frame(tkr, relief=SUNKEN, borderwidth=1)
    scales.pack(side=BOTTOM, padx=10, pady=10, expand=1)

    # The k sliders hold the integer n in
    #   k = 2*pi*n*(unitCellSpins/unitCellSize)/(sites*spacing),
    # the mapping applied by Simulation.set_kx / set_ky; invert it for the
    # initial slider positions.  (Simulation has no Ntotx/Ntoty attributes to
    # read here, so the site counts are computed from the cell geometry.)
    sites_x = simu.unitCellsX*simu.unitCellSizeX
    sites_y = simu.unitCellsY*simu.unitCellSizeY
    cell_ratio_x = float(simu.unitCellSpinsX)/simu.unitCellSizeX
    cell_ratio_y = float(simu.unitCellSpinsY)/simu.unitCellSizeY

    kx_widget = Scale(scales, orient=HORIZONTAL, from_=-6, to=6, resolution=1.0, label="kx/2PiL", command=simu.set_kx)
    kx_widget.set(simu.k.x*sites_x*simu.dx/(2.*pi*cell_ratio_x))
    kx_widget.pack(side=LEFT)

    ky_widget = Scale(scales, orient=VERTICAL, from_=6, to=-6, resolution=1.0, label="ky/2PiL", command=simu.set_ky)
    ky_widget.set(simu.k.y*sites_y*simu.dy/(2.*pi*cell_ratio_y))
    ky_widget.pack(side=LEFT)

    dt_widget = Scale(scales, orient=VERTICAL, from_=200, to=0, label="dt", command=simu.set_dt)
    dt_widget.set(simu.dt*1000.0)
    dt_widget.pack(side=RIGHT)

    gammaratio_widget = Scale(scales, orient=VERTICAL, from_=1000, to=-1000, resolution=1, label="GR", command=simu.set_gammaratio)
    gammaratio_widget.set(simu.spin_mag_ratio*1000.0)
    gammaratio_widget.pack(side=RIGHT)

    ##  Enter the main TKR loop, starting the interface and program
    tkr.mainloop()



##Ntotx_widget = Scale(tkr, orient=VERTICAL, from_=100, to=0, label="Nx", command=lambda str: simu.set_Ntotx(str))
##Ntotx_widget.set(simu.Ntotx)
##Ntotx_widget.pack(side=LEFT)
##
##Ntoty_widget = Scale(tkr, orient=VERTICAL, from_=100, to=0, label="Ny", command=lambda str: simu.set_Ntoty(str))
##Ntoty_widget.set(simu.Ntoty)
##Ntoty_widget.pack(side=LEFT)

##    def set_angleratio(self,angleratio_string):
##        self.spinvec_angleratio = string.atoi(angleratio_string)

##angleratio_widget = Scale(tkr, orient=VERTICAL, from_=1/simu.spinA_sigma, to=1, label="TR", command=lambda str: simu.set_angleratio(str))
##angleratio_widget.set(simu.spinvec_angleratio)
##angleratio_widget.pack(side=RIGHT)
